﻿// This is an open source non-commercial project. Dear PVS-Studio, please check it.
// PVS-Studio Static Code Analyzer for C, C++ and C#: http://www.viva64.com

// ReSharper disable CheckNamespace
// ReSharper disable CommentTypo

/* NotNullableGenerator.cs -- генерирует свойство, которому нельзя присвоить null
 * Ars Magna project, http://arsmagna.ru
 */

#region Using directives

using System.Collections.Generic;
using System.Linq;
using System.Text;

using Microsoft.CodeAnalysis;
using Microsoft.CodeAnalysis.CSharp.Syntax;
using Microsoft.CodeAnalysis.Text;

#endregion

namespace AM.SourceGeneration
{
    /// <summary>
    /// Incremental source generator that, for every field marked with
    /// <c>[NotNullable]</c>, emits a public property into a partial
    /// declaration of the containing class. The generated setter throws
    /// <see cref="System.ArgumentNullException"/> when <c>null</c> is assigned.
    /// </summary>
    [Generator]
    public class NotNullableGenerator
        : IIncrementalGenerator
    {
        #region Constants

        /// <summary>
        /// Fully qualified name of the marker attribute.
        /// </summary>
        private const string AttributeName = "AM.SourceGeneration.NotNullableAttribute";

        /// <summary>
        /// Source text of the marker attribute injected into the compilation.
        /// </summary>
        private const string AttributeText = @"// <auto-generated />
using System;

namespace AM.SourceGeneration
{
    [AttributeUsage (AttributeTargets.Field)]
    internal sealed class NotNullableAttribute: Attribute
    {
    }
}
";

        #endregion

        #region Private members

        /// <summary>
        /// Cheap syntactic filter: only field declarations that carry
        /// at least one attribute list are passed to the semantic stage.
        /// </summary>
        private static bool IsSyntaxTargetForGeneration (SyntaxNode node)
            => node is FieldDeclarationSyntax m && m.AttributeLists.Count > 0;

        /// <summary>
        /// Collects the field symbols declared by the given field declaration
        /// if it is marked with the <c>NotNullable</c> attribute.
        /// </summary>
        /// <param name="context">Syntax context whose node is
        /// a <see cref="FieldDeclarationSyntax"/>.</param>
        /// <returns>Declared field symbols, or an empty list when
        /// the marker attribute is absent.</returns>
        private static List<IFieldSymbol> GetSemanticTargetForGeneration
            (
                GeneratorSyntaxContext context
            )
        {
            var result = new List<IFieldSymbol>();
            var fieldDeclaration = (FieldDeclarationSyntax) context.Node;
            var semanticModel = context.SemanticModel;

            foreach (var attributeList in fieldDeclaration.AttributeLists)
            {
                foreach (var attribute in attributeList.Attributes)
                {
                    // the attribute usage resolves to the attribute constructor,
                    // so the attribute class is the constructor's containing type
                    if (!(semanticModel.GetSymbolInfo (attribute).Symbol is IMethodSymbol constructor))
                    {
                        continue;
                    }

                    if (constructor.ContainingType.ToDisplayString() != AttributeName)
                    {
                        continue;
                    }

                    // one declaration may introduce several variables: "int a, b;"
                    foreach (var variable in fieldDeclaration.Declaration.Variables)
                    {
                        if (semanticModel.GetDeclaredSymbol (variable) is IFieldSymbol fieldSymbol)
                        {
                            result.Add (fieldSymbol);
                        }
                    }

                    // stop at the first match, otherwise a duplicated marker
                    // attribute would register the same fields twice
                    return result;
                }
            }

            return result;
        }

        /// <summary>
        /// Groups the collected fields by containing type and adds
        /// one generated source file per type.
        /// </summary>
        /// <param name="collected">All fields marked with the attribute.</param>
        /// <param name="context">Context used to add the generated sources.</param>
        private static void Execute
            (
                IEnumerable<IFieldSymbol> collected,
                SourceProductionContext context
            )
        {
            var types = collected.GroupBy<IFieldSymbol, INamedTypeSymbol>
                (
                    it => it.ContainingType, SymbolEqualityComparer.Default
                );

            foreach (var group in types)
            {
                var classSource = ProcessClass (group.Key, group.ToList());
                if (!string.IsNullOrEmpty (classSource))
                {
                    context.AddSource
                        (
                            $"{group.Key.Name}_not_nullable.g.cs",
                            SourceText.From (classSource, Encoding.UTF8)
                        );
                }
            }
        }

        /// <summary>
        /// Builds the source of the partial class that holds
        /// the generated properties.
        /// </summary>
        /// <param name="classSymbol">Type that declares the fields.</param>
        /// <param name="fields">Marked fields of that type.</param>
        /// <returns>Generated source text, or an empty string
        /// if no property could be generated.</returns>
        private static string ProcessClass
            (
                INamedTypeSymbol classSymbol,
                IList<IFieldSymbol> fields
            )
        {
            // a type in the global namespace must not be wrapped into
            // a namespace declaration: its display name is not an identifier
            var containingNamespace = classSymbol.ContainingNamespace;
            var hasNamespace = containingNamespace != null
                && !containingNamespace.IsGlobalNamespace;

            var source = new StringBuilder();
            if (hasNamespace)
            {
                var namespaceName = containingNamespace.ToDisplayString();
                source.Append (
                    $@"namespace {namespaceName}
{{
");
            }

            source.Append (
                $@"    partial class {classSymbol.Name}
    {{
");

            var headerLength = source.Length;
            foreach (var fieldSymbol in fields)
            {
                ProcessField (source, fieldSymbol);
            }

            if (source.Length == headerLength)
            {
                // every field was skipped, there is nothing worth emitting
                return string.Empty;
            }

            source.Append (hasNamespace ? "} }" : "}");

            // return the text that was actually built
            return source.ToString();
        }

        /// <summary>
        /// Appends the property that wraps the given field.
        /// Nothing is appended when no distinct property name can be chosen.
        /// </summary>
        /// <param name="source">Builder receiving the generated code.</param>
        /// <param name="fieldSymbol">Field to wrap.</param>
        private static void ProcessField
            (
                StringBuilder source,
                IFieldSymbol fieldSymbol
            )
        {
            var fieldName = fieldSymbol.Name;
            var fieldType = fieldSymbol.Type;

            var propertyName = Utility.ChooseName (fieldName);
            if (propertyName.Length == 0 || propertyName == fieldName)
            {
                // TODO: issue a diagnostic that we can't process this field
                return;
            }

            source.Append
                (
                    $@"

        public {fieldType} {propertyName}
        {{
            get
            {{
                return this.{fieldName};
            }}

            set
            {{
                if (ReferenceEquals (value, null))
                {{
                    throw new System.ArgumentNullException (nameof ({propertyName}));
                }}
                this.{fieldName} = value;
            }}
        }}
"
                );
        }

        #endregion

        #region IIncrementalGenerator members

        /// <inheritdoc cref="IIncrementalGenerator.Initialize"/>
        public void Initialize
            (
                IncrementalGeneratorInitializationContext context
            )
        {
            // declare the marker attribute
            context.RegisterPostInitializationOutput
                (
                    ctx => ctx.AddSource
                        (
                            "NotNullableAttribute.g.cs",
                            SourceText.From (AttributeText, Encoding.UTF8)
                        )
                );

            // select the fields we are interested in
            var declarations = context.SyntaxProvider.CreateSyntaxProvider
                    (
                        predicate: (s, _) => IsSyntaxTargetForGeneration (s),
                        transform: (ctx, _) => GetSemanticTargetForGeneration (ctx)
                    );

            // combine the compilation with the collected fields
            var compilation = context.CompilationProvider.Combine (declarations.Collect());

            // register the method that performs the actual generation
            context.RegisterSourceOutput
                (
                    compilation,
                    (spc, source) =>
                    {
                        var collected = source.Item2.SelectMany (it => it).ToList();
                        Execute (collected, spc);
                    });
        }

        #endregion
    }
}
